/*
Copyright (c) 2025 Huawei Technologies Co., Ltd.
This file is a part of the CANN Open Software.
Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
Please refer to the License for details. You may not use this file except in compliance with the License.
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
See LICENSE in the root of the software repository for the full text of the License. 
*/

#pragma once
#include "kernel_const.h"
#include "kernel_utils.h"
#include "kernel_operator.h"


/**
 * @brief Copies a row-major (ND) tile of A or B from GM into L1 in Nz layout.
 *
 * Normally a single DataCopy with Nd2NzParams moves the whole tile. When the
 * source row stride is not below STRIDE_LIMIT it cannot be expressed as the
 * srcDValue row stride, so the tile is copied one row per DataCopy instead.
 *
 * @param[out] l1_dstLocalTensor: destination tensor in L1.
 * @param[in]  gm_srcGlobalTensor: source tensor in GM.
 * @param[in]  gm_mActual: number of valid rows (M) in GM. No alignment requirement;
 *             at most L1_M0 / L1_N0 / L1_K0.
 * @param[in]  gm_nActual: number of valid columns (N) per row in GM. The Nd2Nz
 *             DataCopy rounds this dimension up to 32 bytes automatically.
 * @param[in]  gm_srcStride: distance, in elements, between the starts of
 *             consecutive source rows in GM.
 * @param[in]  l1_mActualRound: M extent of the tile in L1, i.e. the required size
 *             rounded up to the base block (BASEBLOCK_M0 / BASEBLOCK_N0; for half
 *             BASEBLOCK_M0, N0 and K0 are equal). Passed as dstNzC0Stride.
 * @param[in]  l1_nActualRound: N extent of the tile in L1 (expected to be already
 *             aligned to BASEBLOCK_K0). Not read by this function.
 */

[aicore] inline __attribute__((always_inline)) void Gm2L1_Nd2Nz(
    AscendC::LocalTensor<half> l1_dstLocalTensor,
    AscendC::GlobalTensor<half> gm_srcGlobalTensor,
    uint32_t gm_mActual,
    uint32_t gm_nActual,
    uint32_t gm_srcStride,
    uint32_t l1_mActualRound,
    uint32_t l1_nActualRound
){
    // Number of half elements in one DataBlock; in the per-row path each source
    // row lands this many elements after the previous one in L1.
    const uint32_t datablockSize = DATABLOCK_BYTES / sizeof(half);

    if (gm_srcStride < STRIDE_LIMIT) {
        // Whole tile in one instruction.
        AscendC::DataCopy(
            l1_dstLocalTensor,
            gm_srcGlobalTensor,
            AscendC::Nd2NzParams(
                1,                // ndNum: one ND matrix
                gm_mActual,       // nValue: rows to copy
                gm_nActual,       // dValue: columns per row
                0,                // srcNdMatrixStride (single matrix)
                gm_srcStride,     // srcDValue: source row stride in elements
                l1_mActualRound,  // dstNzC0Stride
                1,                // dstNzNStride: consecutive rows one DataBlock apart
                1                 // dstNzMatrixStride (single matrix)
            )
        );
    } else {
        // The source row stride is at or above STRIDE_LIMIT, so copy row by row
        // and step through GM with explicit offsets. Each copy moves exactly one
        // row, so srcDValue never steps between rows; pass gm_nActual (a valid
        // stride for one contiguous row) rather than narrowing the out-of-range
        // gm_srcStride into the srcDValue field, where it would be truncated.
        for (uint64_t i = 0; i < gm_mActual; i++) {
            AscendC::DataCopy(
                l1_dstLocalTensor[i * datablockSize],
                gm_srcGlobalTensor[i * gm_srcStride],
                AscendC::Nd2NzParams(
                    1,                // ndNum
                    1,                // nValue: a single row
                    gm_nActual,       // dValue
                    0,                // srcNdMatrixStride
                    gm_nActual,       // srcDValue: unused for one row, kept in range
                    l1_mActualRound,  // dstNzC0Stride
                    0,                // dstNzNStride (single row)
                    1                 // dstNzMatrixStride
                )
            );
        }
    }
}

/**
 * @brief Loads the A tile from L1 (Nz layout) into L0A (Zz layout).
 *
 * One LoadData is issued per row of base blocks along M. Each call copies
 * l0a_kActualRound_loop / l0_baseblockK base blocks along K, without transposition.
 *
 * @param[out] l0_dstLocalTensor: destination tensor in L0A.
 * @param[in]  l1_srcLocalTensor: source tensor in L1.
 * @param[in]  l0a_mActualRound: M extent loaded into L0A this time (aligned to square base blocks).
 * @param[in]  l0a_kActualRound_loop: K extent loaded into L0A this time (aligned to square base blocks).
 * @param[in]  l1_srcStride: distance in L1, in DataBlocks, between the starts of adjacent
 *             small z blocks within the same row of blocks.
 * @param[in]  l0_baseblockM: base block size along M in L0.
 * @param[in]  l0_baseblockK: base block size along K in L0.
 */
[aicore] inline __attribute__((always_inline)) void L12L0_Nz2Zz(
    AscendC::LocalTensor<half> l0_dstLocalTensor,
    AscendC::LocalTensor<half> l1_srcLocalTensor,
    uint32_t l0a_mActualRound,
    uint32_t l0a_kActualRound_loop,
    uint32_t l1_srcStride,
    uint32_t l0_baseblockM,
    uint32_t l0_baseblockK
){
    // The load parameters are the same for every M row of base blocks.
    const AscendC::LoadData2DParams rowParams(
        0,                                      // startIndex
        l0a_kActualRound_loop / l0_baseblockK,  // repeatTimes: base blocks along K
        l1_srcStride / l0_baseblockM,           // srcStride between consecutive source base blocks
        0,                                      // sid
        0,                                      // dstGap
        false,                                  // ifTranspose
        0                                       // addrMode
    );

    const uint32_t mBlocks = l0a_mActualRound / l0_baseblockM;
    // Destination advances by one full row of base blocks (baseM x K elements);
    // source advances by one base block (baseM x baseK elements).
    const uint64_t dstStep = static_cast<uint64_t>(l0_baseblockM) * l0a_kActualRound_loop;
    const uint64_t srcStep = static_cast<uint64_t>(l0_baseblockM) * l0_baseblockK;

    uint64_t dstOffset = 0;
    uint64_t srcOffset = 0;
    for (uint32_t row = 0; row < mBlocks; ++row) {
        AscendC::LoadData(l0_dstLocalTensor[dstOffset], l1_srcLocalTensor[srcOffset], rowParams);
        dstOffset += dstStep;
        srcOffset += srcStep;
    }
}

/**
 * @brief Loads the B tile from L1 (Nz layout) into L0B (Zn layout).
 *
 * One LoadData is issued per row of base blocks along K. Each call copies
 * l0b_nActualRound / l0_baseblockN base blocks along N, with transposition
 * enabled (ifTranspose = true).
 *
 * @param[out] l0_dstLocalTensor: destination tensor in L0B.
 * @param[in]  l1_srcLocalTensor: source tensor in L1.
 * @param[in]  l0b_kActualRound_loop: K extent loaded into L0B this time (aligned to square base blocks).
 * @param[in]  l0b_nActualRound: N extent loaded into L0B this time (aligned to square base blocks).
 * @param[in]  l1_srcStride: distance in L1, in DataBlocks, between the starts of adjacent
 *             small z blocks within the same row of blocks.
 * @param[in]  l0_baseblockK: base block size along K in L0.
 * @param[in]  l0_baseblockN: base block size along N in L0.
 */
[aicore] inline __attribute__((always_inline)) void L12L0_Nz2Zn(
    AscendC::LocalTensor<half> l0_dstLocalTensor,
    AscendC::LocalTensor<half> l1_srcLocalTensor,
    uint32_t l0b_kActualRound_loop,
    uint32_t l0b_nActualRound,
    uint32_t l1_srcStride,
    uint32_t l0_baseblockK,
    uint32_t l0_baseblockN
){
    // The load parameters are the same for every K row of base blocks.
    const AscendC::LoadData2DParams rowParams(
        0,                                  // startIndex
        l0b_nActualRound / l0_baseblockN,   // repeatTimes: base blocks along N
        l1_srcStride / l0_baseblockK,       // srcStride between consecutive source base blocks
        0,                                  // sid
        0,                                  // dstGap
        true,                               // ifTranspose
        0                                   // addrMode
    );

    const uint32_t kBlocks = l0b_kActualRound_loop / l0_baseblockK;
    // Destination advances by one full row of base blocks (baseK x N elements);
    // source advances by one base block (baseK x baseN elements).
    const uint64_t dstStep = static_cast<uint64_t>(l0_baseblockK) * l0b_nActualRound;
    const uint64_t srcStep = static_cast<uint64_t>(l0_baseblockK) * l0_baseblockN;

    uint64_t dstOffset = 0;
    uint64_t srcOffset = 0;
    for (uint32_t row = 0; row < kBlocks; ++row) {
        AscendC::LoadData(l0_dstLocalTensor[dstOffset], l1_srcLocalTensor[srcOffset], rowParams);
        dstOffset += dstStep;
        srcOffset += srcStep;
    }
}

/**
 * @brief Copies an AB result tile from L0C (Nz layout, float) to GM (ND layout, half).
 *
 * A single DataCopy is issued with QuantMode_t::F322F16 conversion and the final
 * DataCopyCO12DstParams flag set to true.
 *
 * @param[out] gm_dstGlobalTensor: destination tensor in GM.
 * @param[in]  l0c_srcLocalTensor: source tensor in L0C.
 * @param[in]  l0c_mActual: number of valid elements along M in L0C.
 * @param[in]  l0c_nActual: number of valid elements along N in L0C.
 * @param[in]  l0a_mActualRound: M extent computed in L0A (aligned to square base blocks);
 *             used as the source stride.
 * @param[in]  l0b_nActualRound: N extent computed in L0B (aligned to square base blocks).
 *             Not read by this function.
 * @param[in]  gm_dstStride: distance in GM between the starts of adjacent rows after the copy.
 */
[aicore] inline __attribute__((always_inline)) void L0C2Gm_Nz2Nd(
    AscendC::GlobalTensor<half> gm_dstGlobalTensor,
    AscendC::LocalTensor<float> l0c_srcLocalTensor,
    uint32_t l0c_mActual,
    uint32_t l0c_nActual,
    uint32_t l0a_mActualRound,
    uint32_t l0b_nActualRound,
    uint32_t gm_dstStride
){
    const AscendC::DataCopyCO12DstParams copyParams(
        l0c_nActual,           // nSize
        l0c_mActual,           // mSize
        gm_dstStride,          // dstStride
        l0a_mActualRound,      // srcStride
        QuantMode_t::F322F16,  // quantPre: float -> half
        0,                     // reluPre
        false,                 // channelSplit
        true                   // nz2ndEn
    );
    AscendC::DataCopy(gm_dstGlobalTensor, l0c_srcLocalTensor, copyParams);
}

/**
 * @brief Copies a gm_mActual x gm_nActual tile from GM into UB with DataCopyPad.
 *
 * @tparam elementType element type of both tensors.
 * @param[out] ub_dstLocalTensor: destination tensor in UB.
 * @param[in]  gm_srcGlobalTensor: source tensor in GM.
 * @param[in]  gm_mActual: number of rows to copy (passed as the block count).
 * @param[in]  gm_nActual: number of elements per row to copy.
 * @param[in]  ub_dstTailHeadStride: gap between the end of one destination row and the
 *             start of the next, in DataBlocks.
 * @param[in]  gm_srcStride: distance between the starts of consecutive source rows, in
 *             elements. The byte gap passed to the copy is
 *             (gm_srcStride - gm_nActual) * sizeof(elementType), so callers must ensure
 *             gm_srcStride >= gm_nActual (the unsigned subtraction wraps otherwise).
 * @param[in]  ub_leftPaddingSize: left padding passed to DataCopyPadExtParams.
 * @param[in]  ub_rightPaddingSize: right padding passed to DataCopyPadExtParams.
 *             Padding is requested with isPad = false, so no user padding value is
 *             applied (the padding value argument is 0).
 */
template<typename elementType>
[aicore] inline __attribute__((always_inline)) void Gm2Ub(
    AscendC::LocalTensor<elementType> ub_dstLocalTensor,
    AscendC::GlobalTensor<elementType> gm_srcGlobalTensor,
    uint32_t gm_mActual,
    uint32_t gm_nActual,
    uint32_t ub_dstTailHeadStride,
    uint32_t gm_srcStride,
    uint32_t ub_leftPaddingSize,
    uint32_t ub_rightPaddingSize
){
    // Row length and GM-side inter-row gap, both in bytes.
    const auto rowBytes = gm_nActual * sizeof(elementType);
    const auto srcGapBytes = (gm_srcStride - gm_nActual) * sizeof(elementType);

    const AscendC::DataCopyExtParams copyParams(
        gm_mActual,            // blockCount: rows
        rowBytes,              // blockLen
        srcGapBytes,           // srcStride (GM, bytes)
        ub_dstTailHeadStride,  // dstStride (UB, DataBlocks)
        0                      // rsv
    );
    const AscendC::DataCopyPadExtParams<elementType> padParams(
        false,                 // isPad
        ub_leftPaddingSize,
        ub_rightPaddingSize,
        0                      // paddingValue
    );

    AscendC::DataCopyPad<elementType>(ub_dstLocalTensor, gm_srcGlobalTensor, copyParams, padParams);
}

/**
 * @brief Copies a ub_mActual x ub_nActual tile from UB to GM with DataCopyPad.
 *
 * @tparam elementType element type of both tensors.
 * @param[out] gm_dstGlobalTensor: destination tensor in GM.
 * @param[in]  ub_srcLocalTensor: source tensor in UB.
 * @param[in]  ub_mActual: number of rows to copy (passed as the block count).
 * @param[in]  ub_nActual: number of elements per row to copy.
 * @param[in]  gm_dstStride: distance between the starts of consecutive destination rows,
 *             in elements. The byte gap passed to the copy is
 *             (gm_dstStride - ub_nActual) * sizeof(elementType), so callers must ensure
 *             gm_dstStride >= ub_nActual (the unsigned subtraction wraps otherwise).
 * @param[in]  ub_srcTailHeadStride: gap between the end of one source row and the start
 *             of the next, in DataBlocks.
 */
template<typename elementType>
[aicore] inline __attribute__((always_inline)) void Ub2Gm(
    AscendC::GlobalTensor<elementType> gm_dstGlobalTensor,
    AscendC::LocalTensor<elementType> ub_srcLocalTensor,
    uint32_t ub_mActual,
    uint32_t ub_nActual,
    uint32_t gm_dstStride,
    uint32_t ub_srcTailHeadStride
){
    // Row length and GM-side inter-row gap, both in bytes.
    const auto rowBytes = ub_nActual * sizeof(elementType);
    const auto dstGapBytes = (gm_dstStride - ub_nActual) * sizeof(elementType);

    const AscendC::DataCopyExtParams copyParams(
        ub_mActual,            // blockCount: rows
        rowBytes,              // blockLen
        ub_srcTailHeadStride,  // srcStride (UB, DataBlocks)
        dstGapBytes,           // dstStride (GM, bytes)
        0                      // rsv
    );

    AscendC::DataCopyPad<elementType>(gm_dstGlobalTensor, ub_srcLocalTensor, copyParams);
}